{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Quantization of the GPT-2 Small Model with LLM.int8()\n",
        "This notebook is a companion to Chapter 5 of the book \"Domain Specific LLMs in Action\" by Guglielmo Iozzia, [Manning Publications](https://www.manning.com/), 2024.  \n",
        "It introduces readers to the quantization of a decoder-only language model, [GPT-2 Small](https://huggingface.co/openai-community/gpt2), using [LLM.int8()](https://arxiv.org/abs/2208.07339). It requires hardware acceleration (GPU).  \n",
        "More details about the code can be found in the corresponding chapter of the book."
      ],
      "metadata": {
        "id": "DCgYUlx0DB6j"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Install the missing requirements (Hugging Face Accelerate and bitsandbytes)."
      ],
      "metadata": {
        "id": "rvWC_w9xOq0Z"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install accelerate bitsandbytes"
      ],
      "metadata": {
        "id": "JMUTSN2ftU2v"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Import the required packages and classes."
      ],
      "metadata": {
        "id": "o7jiwZbrDcFS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer"
      ],
      "metadata": {
        "id": "eIDcBafjDNP8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download the GPT-2 model and its associated tokenizer from the Hugging Face Hub and load the model onto the GPU. Finally, print the model's memory footprint in bytes."
      ],
      "metadata": {
        "id": "1WUf_-2EDoUL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "model_id = 'openai-community/gpt2'\n",
        "model = AutoModelForCausalLM.from_pretrained(model_id,\n",
        "                                             device_map='auto')\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "\n",
        "print(f\"Model size: {model.get_memory_footprint():,} bytes\")"
      ],
      "metadata": {
        "id": "L9SqaXHgS28V"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Load the same GPT-2 model from the Hugging Face Hub again, this time quantizing it to 8-bit with LLM.int8() as it is placed on the GPU. Finally, print its memory footprint in bytes."
      ],
      "metadata": {
        "id": "FiVrnxg2PCVq"
      }
    },
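    {
      "cell_type": "code",
      "source": [
        "from transformers import BitsAndBytesConfig\n",
        "\n",
        "# Quantize the weights to 8-bit (LLM.int8() via bitsandbytes) while loading\n",
        "quantization_config = BitsAndBytesConfig(load_in_8bit=True)\n",
        "model_int8 = AutoModelForCausalLM.from_pretrained(model_id,\n",
        "                                                  device_map='auto',\n",
        "                                                  quantization_config=quantization_config)\n",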
        "\n",
        "print(f\"Model size: {model_int8.get_memory_footprint():,} bytes\")"
      ],
      "metadata": {
        "id": "iqXp9kq3U9mM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Collect the weights of the original model and of the 8-bit model, then flatten each set into a single NumPy array so that it can be plotted as a histogram."
      ],
      "metadata": {
        "id": "yEmTUmM5Pr8g"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "# Flatten every parameter tensor of the original model into one 1-D NumPy array\n",
        "weights = [param.data.clone() for param in model.parameters()]\n",
        "weights = np.concatenate([t.cpu().numpy().flatten() for t in weights])\n",
        "# Same for the 8-bit model; its quantized layers expose their stored INT8 values\n",
        "weights_int8 = [param.data.clone() for param in model_int8.parameters()]\n",
        "weights_int8 = np.concatenate([t.cpu().numpy().flatten() for t in weights_int8])"
      ],
      "metadata": {
        "id": "Qw78l_Yar4KW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Using matplotlib, plot the weight distributions of the original model and of the 8-bit version on the same histogram."
      ],
      "metadata": {
        "id": "XJQFjswmE0A0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.ticker as ticker"
      ],
      "metadata": {
        "id": "nFbi5Oxds22n"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set background style\n",
        "plt.style.use('ggplot')\n",
        "\n",
        "# Create figure and axes\n",
        "fig, axs = plt.subplots(1, figsize=(10,10), dpi=300, sharex=True)\n",
        "\n",
        "# Plot the histograms for original and LLM.int8() weights\n",
        "axs.hist(weights, bins=150, alpha=0.5, label='Original weights', color='blue', range=(-2, 2))\n",
        "axs.hist(weights_int8, bins=150, alpha=0.5, label='LLM.int8() weights', color='yellow', range=(-2, 2))\n",
        "\n",
        "# Add grid\n",
        "axs.grid(True, linestyle='--', alpha=0.6)\n",
        "\n",
        "# Add legend\n",
        "axs.legend()\n",
        "\n",
        "# Add title and labels\n",
        "axs.set_title('Comparison of Original and LLM.int8() Weights', fontsize=16)\n",
        "\n",
        "axs.set_xlabel('Weights', fontsize=14)\n",
        "axs.set_ylabel('Count', fontsize=14)\n",
        "axs.yaxis.set_major_formatter(ticker.EngFormatter()) # Make y-ticks more human readable\n",
        "\n",
        "# Improve font\n",
        "plt.rc('font', size=12)\n",
        "\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "cLn1dKzJs3aK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define a function that generates text with a given model (original or quantized)."
      ],
      "metadata": {
        "id": "Broep3T1FLtb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def generate_text(model, input_text, max_length=100):\n",
        "    # Relies on the global `tokenizer` and `device` defined in earlier cells\n",
        "    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)\n",
        "    # Sample with top-k; max_length counts the prompt tokens plus the generated ones\n",
        "    output = model.generate(inputs=input_ids,\n",
        "                            max_length=max_length,\n",
        "                            do_sample=True,\n",
        "                            top_k=30,\n",
        "                            pad_token_id=tokenizer.eos_token_id,\n",
        "                            attention_mask=input_ids.new_ones(input_ids.shape))\n",
        "    return tokenizer.decode(output[0], skip_special_tokens=True)"
      ],
      "metadata": {
        "id": "ccynNm1wqCgE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Use the text generation function defined in the previous code cell to generate text with both model versions (the original and the 8-bit quantized one)."
      ],
      "metadata": {
        "id": "E_TbqdZYFTWm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = 'My favourite school subject is'\n",
        "original_text = generate_text(model, prompt)\n",
        "text_int8 = generate_text(model_int8, prompt)\n",
        "\n",
        "print(f\"Original model:\\n{original_text}\")\n",
        "print(f\"LLM.int8() model:\\n{text_int8}\")"
      ],
      "metadata": {
        "id": "_eplky8iVCjf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define a function to calculate the perplexity score."
      ],
      "metadata": {
        "id": "a5TAMxp5Fivf"
      }
    },
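    {
      "cell_type": "code",
      "source": [
        "def calculate_perplexity(model, text, device):\n",
        "    # Relies on the global `tokenizer` defined in an earlier cell\n",
        "    encodings = tokenizer(text, return_tensors='pt').to(device)\n",
        "\n",
        "    input_ids = encodings.input_ids\n",
        "    target_ids = input_ids.clone()\n",
        "\n",
        "    # With labels supplied, the model returns the mean cross-entropy loss,\n",
        "    # i.e. the average negative log-likelihood per predicted token\n",
        "    with torch.no_grad():\n",
        "        outputs = model(input_ids, labels=target_ids)\n",
        "\n",
        "    neg_log_likelihood = outputs.loss\n",
        "\n",
        "    # Perplexity is the exponential of the average negative log-likelihood\n",
        "    perplexity = torch.exp(neg_log_likelihood)\n",
        "\n",
        "    return perplexity"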
      ],
      "metadata": {
        "id": "iZABZBZ-4M1w"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Calculate the perplexity score of each model version on the text that it generated in the previous step."
      ],
      "metadata": {
        "id": "xTB0KIHORbjH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "perplexity = calculate_perplexity(model, original_text, device)\n",
        "perplexity_int8 = calculate_perplexity(model_int8, text_int8, device)\n",
        "print(f\"Original Perplexity:   {perplexity.item():.2f}\")\n",
        "print(f\"LLM.int8() perplexity: {perplexity_int8.item():.2f}\")"
      ],
      "metadata": {
        "id": "Cf0dEAkFVmp5"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}